from __future__ import annotations

import argparse
from dataclasses import dataclass
from pathlib import Path

import numpy as np
import pandas as pd
import seaborn as sns
import statsmodels.api as sm
import statsmodels.formula.api as smf
from matplotlib import pyplot as plt


# Headline interpretation of the analysis. This exact text is embedded in the
# scatter-plot title (ols_country_cross_section) and written verbatim to
# tables/take_home_message.txt by main().
TAKE_HOME_MESSAGE = (
    "More media/information freedom is associated with greater transparency during COVID (smaller excess-vs-reported "
    "mortality gaps) and lower true health losses, consistent with information environments functioning as public "
    "health capacity—especially for vulnerable groups."
)


@dataclass(frozen=True)
class Spec:
    """Immutable description of one outcome variable to analyse.

    Attributes:
        y: Name of the panel column that is summed per country for the
            cross-sectional regression; also used in output file names.
        y_label: Human-readable outcome name used in figure titles and
            axis labels.
        y_panel: Name of the panel column used as the monthly outcome in
            the month-by-month slope regressions; also used in output
            file names.
    """

    y: str
    y_label: str
    y_panel: str


def ensure_dirs(out_dir: Path) -> None:
    """Create the ``tables`` and ``figures`` subdirectories under *out_dir*.

    Missing parent directories are created as well, and directories that
    already exist are left untouched.
    """
    for subdir_name in ("tables", "figures"):
        target = out_dir / subdir_name
        target.mkdir(parents=True, exist_ok=True)


def ols_country_cross_section(df: pd.DataFrame, spec: Spec, out_dir: Path) -> None:
    """Fit a country-level OLS of a per-million outcome on media freedom.

    The panel is collapsed to one row per
    (iso_code, continent, location, media_freedom) group: the outcome column
    ``spec.y`` is summed over all rows and the covariates are taken as their
    per-group maximum. The summed outcome is scaled to a per-million rate and
    regressed on media freedom plus controls with HC1 robust standard errors.

    Outputs written under *out_dir* (the ``tables`` and ``figures``
    subdirectories are expected to exist already):
        * ``tables/ols_{spec.y}_pm.txt`` - text regression summary.
        * ``figures/scatter_{spec.y}_pm_media_freedom.png`` - scatter plot with
          a fitted line and 95% confidence band.

    Args:
        df: Panel with at least the columns ``iso_code``, ``continent``,
            ``location``, ``media_freedom``, ``population``,
            ``gdp_per_capita``, ``median_age``,
            ``hospital_beds_per_thousand``, ``diabetes_prevalence``,
            ``human_development_index`` and ``spec.y``.
        spec: Outcome specification; ``spec.y`` and ``spec.y_label`` are used.
        out_dir: Root output directory.
    """
    g = df[df["media_freedom"].notna()].copy()

    g = (
        g.groupby(["iso_code", "continent", "location", "media_freedom"], as_index=False)
        .agg(
            # min_count=1: a country with no observed outcome at all must stay
            # NaN (and be dropped below) instead of being summed to a spurious 0.
            y=(spec.y, lambda s: s.sum(min_count=1)),
            population=("population", "max"),
            gdp_per_capita=("gdp_per_capita", "max"),
            median_age=("median_age", "max"),
            hospital_beds_per_thousand=("hospital_beds_per_thousand", "max"),
            diabetes_prevalence=("diabetes_prevalence", "max"),
            human_development_index=("human_development_index", "max"),
        )
    )

    # Non-positive population or GDP would produce +/-inf (or NaN with runtime
    # warnings); mask them to NaN so those countries are dropped by dropna().
    population = g["population"].where(g["population"] > 0)
    gdp_per_capita = g["gdp_per_capita"].where(g["gdp_per_capita"] > 0)
    g["y_pm"] = (g["y"] / population) * 1_000_000
    g["log_gdp_pc"] = np.log(gdp_per_capita)

    X_cols = [
        "media_freedom",
        "log_gdp_pc",
        "median_age",
        "hospital_beds_per_thousand",
        "diabetes_prevalence",
        "human_development_index",
    ]
    # Complete-case sample for the regression.
    d = g[["y_pm"] + X_cols].dropna().copy()
    X = sm.add_constant(d[X_cols])
    model = sm.OLS(d["y_pm"], X).fit(cov_type="HC1")

    table_path = out_dir / "tables" / f"ols_{spec.y}_pm.txt"
    table_path.write_text(model.summary().as_text(), encoding="utf-8")

    fig_path = out_dir / "figures" / f"scatter_{spec.y}_pm_media_freedom.png"
    plt.figure(figsize=(9, 6))
    try:
        # The scatter uses every aggregated country with both plotted values
        # present, not only the complete-case regression sample.
        sns.regplot(
            data=g,
            x="media_freedom",
            y="y_pm",
            scatter_kws={"alpha": 0.4, "s": 18},
            line_kws={"color": "black"},
            ci=95,
        )
        plt.title(f"{spec.y_label} vs media freedom (baseline 2019)\n{TAKE_HOME_MESSAGE}")
        plt.xlabel("Media freedom (V-Dem v2x_freexp_altinf, baseline ≤2019)")
        plt.ylabel(f"{spec.y_label} per million (sum over window)")
        plt.tight_layout()
        plt.savefig(fig_path, dpi=200)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close()


def panel_event_slope(df: pd.DataFrame, spec: Spec, out_dir: Path, base_month: str, min_countries: int = 30) -> None:
    """Estimate and plot month-by-month slopes of the outcome on media freedom.

    For every calendar month with at least *min_countries* distinct countries
    (and at least that many complete rows), a separate cross-country OLS of
    ``spec.y_panel`` on ``media_freedom`` plus the available controls is fitted
    with HC1 robust standard errors. The media-freedom coefficient, its
    standard error and a ``coef ± 1.96·se`` interval are collected per month.

    Outputs written under *out_dir* (the ``tables`` and ``figures``
    subdirectories are expected to exist already):
        * ``tables/panel_{spec.y_panel}_event_slope_coeffs.csv``
        * ``figures/event_slope_{spec.y_panel}.png``

    Nothing is written when no month qualifies.

    Args:
        df: Panel with at least ``iso_code``, ``month``, ``media_freedom`` and
            ``spec.y_panel``. Controls (``log_gdp_pc``, ``median_age``,
            ``hospital_beds_per_thousand``) are used when present;
            ``log_gdp_pc`` is derived from ``gdp_per_capita`` when only the
            latter is available. *df* itself is not modified.
        spec: Outcome specification; ``spec.y_panel`` and ``spec.y_label`` are
            used.
        out_dir: Root output directory.
        base_month: Currently unused; kept so existing callers keep working.
        min_countries: Minimum number of countries a month needs to be
            estimated.
    """
    d = df[df["media_freedom"].notna()].dropna(subset=[spec.y_panel]).copy()

    d["month"] = pd.to_datetime(d["month"])
    d["month_id"] = d["month"].dt.to_period("M").astype(str)

    month_counts = d.groupby("month_id")["iso_code"].nunique().rename("n_countries")
    keep_months = month_counts[month_counts >= min_countries].index.tolist()
    d = d[d["month_id"].isin(keep_months)].copy()

    # Derive log GDP the same way the cross-sectional model does when the panel
    # only carries the raw level; otherwise this control would be silently
    # dropped and the "adjusted" slope would not be adjusted for income.
    if "log_gdp_pc" not in d.columns and "gdp_per_capita" in d.columns:
        gdp_per_capita = d["gdp_per_capita"]
        d["log_gdp_pc"] = np.log(gdp_per_capita.where(gdp_per_capita > 0))

    # Month-by-month slopes (more interpretable than dummy-FE interactions)
    controls = [
        c for c in ("log_gdp_pc", "median_age", "hospital_beds_per_thousand") if c in d.columns
    ]
    regressors = ["media_freedom", *controls]

    rows = []
    # groupby(sort=True) visits the "YYYY-MM" month ids in chronological order.
    for month_id, month_rows in d.groupby("month_id", sort=True):
        dm = month_rows[[spec.y_panel, *regressors]].dropna()
        # Missing controls can shrink a month below the threshold again.
        if dm.shape[0] < min_countries:
            continue
        X = sm.add_constant(dm[regressors])
        fit = sm.OLS(dm[spec.y_panel], X).fit(cov_type="HC1")
        rows.append(
            {
                "month": month_id,
                "coef": float(fit.params["media_freedom"]),
                "se": float(fit.bse["media_freedom"]),
            }
        )

    coef_df = pd.DataFrame(rows)
    if coef_df.empty:
        return
    coef_df["month"] = pd.PeriodIndex(coef_df["month"], freq="M").to_timestamp()
    coef_df["ci_lo"] = coef_df["coef"] - 1.96 * coef_df["se"]
    coef_df["ci_hi"] = coef_df["coef"] + 1.96 * coef_df["se"]

    coef_csv = out_dir / "tables" / f"panel_{spec.y_panel}_event_slope_coeffs.csv"
    coef_df.to_csv(coef_csv, index=False, encoding="utf-8")

    fig_path = out_dir / "figures" / f"event_slope_{spec.y_panel}.png"
    plt.figure(figsize=(10, 5))
    try:
        plt.plot(coef_df["month"], coef_df["coef"], color="black", linewidth=1.5)
        plt.fill_between(coef_df["month"], coef_df["ci_lo"], coef_df["ci_hi"], color="black", alpha=0.15)
        plt.axhline(0, color="gray", linewidth=1)
        plt.title(f"Monthly slope (adjusted): media freedom → {spec.y_label}")
        plt.xlabel("Month")
        plt.ylabel(f"Δ {spec.y_label} per +1 media freedom")
        plt.tight_layout()
        plt.savefig(fig_path, dpi=200)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close()


def main() -> None:
    """Command-line entry point: run both analyses for every outcome spec.

    Parses ``--panel``, ``--out-dir`` and ``--base-month``, prepares the output
    directories, loads the merged panel from parquet, writes the take-home
    message to ``tables/take_home_message.txt`` and then produces the
    cross-sectional OLS and the monthly slope outputs for each outcome.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--panel", type=Path, default=Path("outputs/data/panel_merged.parquet"))
    cli.add_argument("--out-dir", type=Path, default=Path("outputs"))
    cli.add_argument("--base-month", type=str, default="2020-01")
    opts = cli.parse_args()
    out_root = opts.out_dir

    ensure_dirs(out_root)
    panel = pd.read_parquet(opts.panel)

    # Outcomes analysed: the excess-vs-reported gap and raw excess deaths.
    gap_spec = Spec(
        y="gap_excess_minus_reported",
        y_label="Excess deaths - reported COVID deaths",
        y_panel="gap_pm_excess_minus_reported",
    )
    excess_spec = Spec(
        y="excess_deaths",
        y_label="Excess deaths",
        y_panel="excess_deaths_pm",
    )

    message_path = out_root / "tables" / "take_home_message.txt"
    message_path.write_text(TAKE_HOME_MESSAGE, encoding="utf-8")

    for outcome in (gap_spec, excess_spec):
        ols_country_cross_section(panel, outcome, out_root)
        panel_event_slope(panel, outcome, out_root, opts.base_month)

    print("Done.")


# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
